nlp_architect.nn.torch package


nlp_architect.nn.torch.distillation module

class nlp_architect.nn.torch.distillation.TeacherStudentDistill(teacher_model: nlp_architect.models.TrainableModel, temperature: float = 1.0, dist_w: float = 0.1, loss_w: float = 1.0, loss_function='kl')

Bases: object

Teacher-Student knowledge distillation helper. Use this object when training a student model with knowledge distillation (KD) from a teacher model.

Parameters:
  • teacher_model (TrainableModel) – teacher model
  • temperature (float, optional) – KD temperature. Defaults to 1.0.
  • dist_w (float, optional) – distillation loss weight. Defaults to 0.1.
  • loss_w (float, optional) – student loss weight. Defaults to 1.0.
  • loss_function (str, optional) – loss function to use ('kl' for KLDivLoss, 'mse' for MSELoss). Defaults to 'kl'.
static add_args(parser: argparse.ArgumentParser)

Add the KD arguments to the parser.

Parameters: parser (argparse.ArgumentParser) – parser to add the KD arguments to
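distill_loss(loss, student_logits, teacher_logits)

Add the KD loss to the student loss.

Parameters:
  • loss – student loss
  • student_logits – student model logits
  • teacher_logits – teacher model logits

Returns: KD loss (the student loss weighted by loss_w combined with the distillation loss weighted by dist_w)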
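The weighted combination described above can be sketched framework-free as follows. This is a hypothetical re-implementation, not the library's code: it assumes Hinton-style KD, where the KL divergence between the temperature-softened teacher and student distributions is scaled by temperature**2, and it covers only the 'kl' option. All function names are illustrative.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax; subtracting the max keeps exp() from overflowing."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for discrete distributions; terms with p_i == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distill_loss_sketch(
    loss: float,
    student_logits: Sequence[float],
    teacher_logits: Sequence[float],
    temperature: float = 1.0,
    dist_w: float = 0.1,
    loss_w: float = 1.0,
) -> float:
    """Hypothetical weighted KD objective for a single example (not the library code)."""
    teacher_probs = softmax(teacher_logits, temperature)
    student_probs = softmax(student_logits, temperature)
    # Assumption: temperature**2 rescaling, so the soft-target term keeps a
    # comparable gradient magnitude as the temperature changes.
    kd = kl_divergence(teacher_probs, student_probs) * temperature ** 2
    return loss_w * loss + dist_w * kd


# Usage: a student that disagrees with the teacher pays an extra KD penalty.
print(distill_loss_sketch(0.5, [3.0, 2.0, 1.0], [1.0, 2.0, 3.0], temperature=2.0))
```

When the student and teacher logits are identical, the KL term is zero and the result reduces to loss_w * loss.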


Get the teacher logits.

Parameters: inputs – inputs fed to the teacher model
Returns: teacher logits

nlp_architect.nn.torch.quantization module

Quantization ops

class nlp_architect.nn.torch.quantization.FakeLinearQuantizationWithSTE

Bases: torch.autograd.function.Function

Simulates the error caused by quantization. Uses the Straight-Through Estimator (STE) for backpropagation.

static backward(ctx, grad_output)

Calculate estimated gradients for fake quantization using the Straight-Through Estimator (STE) according to:

static forward(ctx, input, scale, bits=8)

Fake-quantize the input according to the scale and number of bits, i.e. compute dequantize(quantize(input)).
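A list-based sketch of the two passes, assuming symmetric quantization that clamps to ±(2**(bits-1) - 1) and the plain, unclipped STE. The function names are hypothetical; the real class operates on tensors through torch.autograd.Function.

```python
from typing import List, Sequence


def fake_quantize(values: Sequence[float], scale: float, bits: int = 8) -> List[float]:
    """Forward: dequantize(quantize(x)), i.e. snap each value to the integer grid and map back."""
    max_q = 2 ** (bits - 1) - 1
    result = []
    for v in values:
        q = max(-max_q, min(max_q, round(v * scale)))  # quantize: round (half to even), then clamp
        result.append(q / scale)                       # dequantize back to float
    return result


def ste_backward(grad_output: Sequence[float]) -> List[float]:
    """Backward: round() has zero gradient almost everywhere, so the STE treats the
    fake-quantization step as the identity and passes the incoming gradient through."""
    return list(grad_output)


# Usage: with scale = 127 / 1.0, in-range values move by at most 0.5 / scale,
# while out-of-range values are clamped to the largest representable magnitude.
scale = 127 / 1.0
print(fake_quantize([0.1, -0.5, 2.0], scale))
print(ste_backward([0.3, -1.2, 0.7]))
```

The forward pass exposes the network to rounding and clamping error during training, while the identity backward pass keeps gradients flowing despite the non-differentiable rounding.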

class nlp_architect.nn.torch.quantization.QuantizationConfig(**kwargs)

Bases: nlp_architect.common.config.Config

Quantization Configuration Object

ATTRIBUTES = {'activation_bits': 8, 'ema_decay': 0.9999, 'mode': 'none', 'requantize_output': True, 'start_step': 0, 'weight_bits': 8}
class nlp_architect.nn.torch.quantization.QuantizationMode

Bases: enum.Enum

Enumeration of quantization modes.

EMA = 3
NONE = 1
class nlp_architect.nn.torch.quantization.QuantizedEmbedding(*args, weight_bits=8, start_step=0, mode='none', **kwargs)

Bases: nlp_architect.nn.torch.quantization.QuantizedLayer, torch.nn.modules.sparse.Embedding

Embedding layer with quantization-aware training capability.


Forward pass to be used during inference.


Return the quantized embeddings.

class nlp_architect.nn.torch.quantization.QuantizedLayer(*args, weight_bits=8, start_step=0, mode='none', **kwargs)

Bases: abc.ABC

Quantized Layer interface

CONFIG_ATTRIBUTES = ['weight_bits', 'start_step', 'mode']
REPR_ATTRIBUTES = ['mode', 'weight_bits']
classmethod from_config(*args, config=None, **kwargs)

Initialize quantized layer from config
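One way to picture from_config is as merging user overrides over the QuantizationConfig.ATTRIBUTES defaults and then keeping only the layer's CONFIG_ATTRIBUTES as constructor keyword arguments. The sketch below assumes exactly that; resolve_config and layer_kwargs are hypothetical helpers working on plain dicts, not the library's API.

```python
from typing import Any, Dict, Mapping, Optional, Sequence

# Defaults as documented in QuantizationConfig.ATTRIBUTES.
QUANT_DEFAULTS: Dict[str, Any] = {
    "activation_bits": 8,
    "ema_decay": 0.9999,
    "mode": "none",
    "requantize_output": True,
    "start_step": 0,
    "weight_bits": 8,
}

# Attribute lists as documented for QuantizedLayer and QuantizedLinear.
LAYER_CONFIG_ATTRIBUTES = ["weight_bits", "start_step", "mode"]
LINEAR_CONFIG_ATTRIBUTES = LAYER_CONFIG_ATTRIBUTES + [
    "activation_bits", "requantize_output", "ema_decay",
]


def resolve_config(overrides: Optional[Mapping[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical: user-supplied values win; everything else falls back to the defaults."""
    return {**QUANT_DEFAULTS, **(overrides or {})}


def layer_kwargs(config: Mapping[str, Any], attributes: Sequence[str]) -> Dict[str, Any]:
    """Hypothetical: pick only the attributes a given layer type accepts."""
    return {name: config[name] for name in attributes}


config = resolve_config({"mode": "ema", "start_step": 1000})
print(layer_kwargs(config, LINEAR_CONFIG_ATTRIBUTES))
```

Keeping the attribute lists per layer type lets one shared configuration drive both embedding and linear layers, each taking only the keys it understands.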


Implement the forward method to be used during evaluation.


Handle the transition between the quantized model and simulated quantization.


Implement the forward method to be used during training.
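The interface above splits the forward pass into a training path (simulated quantization) and an evaluation path. Below is a minimal sketch of such a dispatch using abc, assuming that mode='none' disables quantization and that fake quantization only starts after start_step training steps. Every method and attribute name here is hypothetical, and the real layers are torch modules.

```python
from abc import ABC, abstractmethod


class QuantizedLayerSketch(ABC):
    """Hypothetical stand-in for the interface: the base class dispatches, subclasses implement."""

    def __init__(self, mode: str = "none", start_step: int = 0) -> None:
        self.mode = mode
        self.start_step = start_step
        self.training = True
        self.steps_seen = 0  # hypothetical counter of training forward passes

    def forward(self, x):
        if self.mode == "none":
            return self.float_forward(x)
        if self.training:
            self.steps_seen += 1
            # Assumption: keep full precision until start_step steps have passed.
            if self.steps_seen <= self.start_step:
                return self.float_forward(x)
            return self.training_quantized_forward(x)
        return self.inference_quantized_forward(x)

    @abstractmethod
    def float_forward(self, x): ...

    @abstractmethod
    def training_quantized_forward(self, x): ...

    @abstractmethod
    def inference_quantized_forward(self, x): ...


class TaggingLayer(QuantizedLayerSketch):
    """Toy subclass that reports which path was taken."""

    def float_forward(self, x):
        return "float"

    def training_quantized_forward(self, x):
        return "fake-quantized"

    def inference_quantized_forward(self, x):
        return "integer-inference"


layer = TaggingLayer(mode="ema", start_step=1)
print(layer.forward(0.0), layer.forward(0.0))  # first step still float, then fake-quantized
layer.training = False
print(layer.forward(0.0))
```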

class nlp_architect.nn.torch.quantization.QuantizedLinear(*args, activation_bits=8, requantize_output=True, ema_decay=0.9999, **kwargs)

Bases: nlp_architect.nn.torch.quantization.QuantizedLayer, torch.nn.modules.linear.Linear

Linear layer with quantization-aware training capability.

CONFIG_ATTRIBUTES = ['weight_bits', 'start_step', 'mode', 'activation_bits', 'requantize_output', 'ema_decay']
REPR_ATTRIBUTES = ['mode', 'weight_bits', 'activation_bits', 'accumulation_bits', 'ema_decay', 'requantize_output']

Simulate quantized inference: quantize the input and perform the calculation using only integer arithmetic. This function should only be used during inference.
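To make "integer-only" concrete, the sketch below quantizes the input and each weight row with symmetric scales, accumulates integer products, and rescales once at the end. It is a list-based illustration with hypothetical names; quantizing the bias with the product scale x_scale * w_scale is an assumption (a common choice), and output requantization (requantize_output) is not modeled.

```python
from typing import List, Sequence


def quantize_ints(values: Sequence[float], scale: float, bits: int = 8) -> List[int]:
    """Symmetric linear quantization to integers in [-(2**(bits-1) - 1), 2**(bits-1) - 1]."""
    max_q = 2 ** (bits - 1) - 1
    return [max(-max_q, min(max_q, round(v * scale))) for v in values]


def integer_linear(
    x: Sequence[float],
    weight: Sequence[Sequence[float]],
    bias: Sequence[float],
    x_scale: float,
    w_scale: float,
    bits: int = 8,
) -> List[float]:
    """Hypothetical y = W x + b where the multiply-accumulate uses integers only."""
    qx = quantize_ints(x, x_scale, bits)
    acc_scale = x_scale * w_scale  # scale of every integer product qx[i] * qw[i]
    outputs = []
    for row, b in zip(weight, bias):
        qw = quantize_ints(row, w_scale, bits)
        acc = sum(xi * wi for xi, wi in zip(qx, qw))  # integer accumulator
        qb = round(b * acc_scale)  # assumption: bias shares the accumulator scale
        outputs.append((acc + qb) / acc_scale)  # single float rescale at the end
    return outputs


# Scales chosen as max_q / max|value|: max|x| = 0.5 and max|W| = 1.0 for 8 bits.
x = [0.5, -0.25]
weight = [[1.0, 0.5], [0.0, -1.0]]
bias = [0.1, 0.0]
print(integer_linear(x, weight, bias, x_scale=127 / 0.5, w_scale=127 / 1.0))
```

The float reference for this input is [0.475, 0.25]; the integer path lands within about 0.002 of it.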


Fake-quantized forward pass: fake-quantizes the weights and activations, and learns the quantization ranges if the quantization mode is EMA. This function should only be used during training.
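For the EMA mode, one common way to learn quantization ranges is to keep an exponential moving average of each batch's maximum absolute activation. The sketch assumes that statistic, initializes the average from the first batch, and uses a hypothetical function name; decay plays the role of the documented ema_decay parameter (default 0.9999).

```python
from typing import Optional, Sequence


def update_ema_threshold(
    ema: Optional[float], batch: Sequence[float], decay: float = 0.9999
) -> float:
    """Hypothetical EMA of max|activation|; the first batch initializes the average."""
    batch_max = max(abs(v) for v in batch)
    if ema is None:
        return batch_max
    return decay * ema + (1.0 - decay) * batch_max


ema = None
for batch in ([0.5, -2.0], [1.0, 3.0], [-4.0]):
    ema = update_ema_threshold(ema, batch, decay=0.5)  # small decay so the effect is visible
print(ema)
```

With a decay close to 1, the tracked range changes slowly and is robust to occasional outlier batches; the quantization scale then follows from the tracked threshold, as with get_scale.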


Calculate the maximum symmetric quantized value according to the number of bits.

nlp_architect.nn.torch.quantization.dequantize(input, scale)

Linear dequantization according to a given scale.

nlp_architect.nn.torch.quantization.get_dynamic_scale(x, bits, with_grad=False)

Calculate a dynamic quantization scale from the input x, based on the maximum absolute value of x and the number of bits.

nlp_architect.nn.torch.quantization.get_scale(bits, threshold)

Calculate the quantization scale from a constant threshold and the number of bits.

nlp_architect.nn.torch.quantization.quantize(input, scale, bits)

Apply linear quantization to the input according to a scale and number of bits.
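The helpers above fit together as symmetric linear quantization: the scale maps a threshold to the largest integer 2**(bits-1) - 1, quantization rounds and clamps, and dequantization divides by the scale. The list-based stand-ins below assume that formulation. They mirror the documented names but are not the library functions, which operate on tensors; max_symmetric_value is a hypothetical name, and the with_grad option of get_dynamic_scale is not modeled.

```python
from typing import List, Sequence


def max_symmetric_value(bits: int) -> int:
    """Largest integer in a symmetric signed range, e.g. 127 for 8 bits (-128 is unused)."""
    return 2 ** (bits - 1) - 1


def get_scale(bits: int, threshold: float) -> float:
    """Scale that maps +/-threshold onto +/-max_symmetric_value(bits)."""
    return max_symmetric_value(bits) / threshold


def get_dynamic_scale(x: Sequence[float], bits: int) -> float:
    """Scale derived from the data itself: threshold = max |x|."""
    return get_scale(bits, max(abs(v) for v in x))


def quantize(x: Sequence[float], scale: float, bits: int) -> List[int]:
    """Round to the nearest integer (half to even, as Python's round does) and clamp."""
    limit = max_symmetric_value(bits)
    return [max(-limit, min(limit, round(v * scale))) for v in x]


def dequantize(q: Sequence[int], scale: float) -> List[float]:
    """Map integers back to real values."""
    return [v / scale for v in q]


x = [0.5, -2.0, 0.01]
scale = get_dynamic_scale(x, bits=8)  # 127 / 2.0 = 63.5
q = quantize(x, scale, bits=8)        # [32, -127, 1]
print(scale, q, dequantize(q, scale))
```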

Module contents

nlp_architect.nn.torch.set_seed(seed, n_gpus=None)

Set and return the random seed.
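A stdlib-only sketch of the seeding contract (set the seed, then return it). The function name is hypothetical and it seeds only Python's random module; a full implementation for this package would presumably also seed NumPy and PyTorch (including CUDA generators when n_gpus is positive), which is omitted here to keep the example dependency-free.

```python
import random
from typing import Optional


def set_seed_sketch(seed: int, n_gpus: Optional[int] = None) -> int:
    """Hypothetical: seed Python's RNG and return the seed so callers can log it.

    n_gpus is accepted only to mirror the documented signature; seeding GPU
    generators would require torch and is deliberately left out.
    """
    random.seed(seed)
    return seed


set_seed_sketch(42)
first_run = [random.random() for _ in range(3)]
set_seed_sketch(42)
second_run = [random.random() for _ in range(3)]
print(first_run == second_run)  # re-seeding reproduces the same sequence
```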


Set up the backend according to the selected backend and the detected configuration.